library(tidyverse)
library(brms)
library(tidybayes)
library(bayesplot)
library(cowplot)
library(scales)
library(hexbin)
library(glue)

# Do not prompt before each new plot page.
devAskNewPage(ask = FALSE)

# Default ggplot2 theme for all figures in this script.
theme_set(theme_light())

# Source the helper scripts. They presumably define the simulation helpers,
# the brms model pieces and the priors used below (e.g. `bprior_full`) --
# confirm against those files.
helper_scripts <- c(
  glue("{params$common_dir_str}/simulation.R"),
  glue("{params$common_dir_str}/brms_model.R"),
  glue("{params$model_dir_str}/model_prior.R")
)
for (helper_script in helper_scripts) {
  source(helper_script)
}

# Show the full prior specification.
print(bprior_full)
##              prior class        coef group resp   dpar nlpar bound
## 1 normal(3.8, 0.4)     b   intercept            circSD            
## 2   normal(0, 0.4)     b stimulation            circSD            
## 3  normal(0, 0.25)    sd   Intercept  subj      circSD            
## 4  normal(0, 0.25)    sd stimulation  subj      circSD            
## 5   normal(0, 1.5)     b   intercept             theta            
## 6     normal(0, 1)     b stimulation             theta            
## 7   normal(0, 0.5)    sd   Intercept  subj       theta            
## 8   normal(0, 0.5)    sd stimulation  subj       theta

load sim data

# Stimulation condition codes used as the `stimulation` predictor.
conditions <- c(0, 1)

# Simulated datasets are cached on disk: an existing file is reused as-is,
# otherwise the data are simulated once and saved to this path.
sim_datasets_fpath <- glue("{params$save_dir_str}/sim_datasets.rds")

if (file.exists(sim_datasets_fpath)) {

  sim_datasets <- readRDS(sim_datasets_fpath)

} else {

  print("simulating")

  nsim_datasets <- 1

  # One row of generating (group-level) parameters per simulated dataset.
  # The *_prior_* values are not defined in this file; presumably they come
  # from the sourced prior script.
  sim_priors <- tibble(
    sim_num = seq_len(nsim_datasets),
    alpha0_mu = alpha0_mu_prior_mu,
    alpha0_sigma = alpha0_sigma_prior_sd,
    alphaD_mu = alphaD_mu_prior_mu,
    alphaD_sigma = alphaD_sigma_prior_sd,
    beta0_mu = beta0_mu_prior_mu,
    beta0_sigma = beta0_sigma_prior_sd,
    betaD_mu = betaD_mu_prior_mu,
    betaD_sigma = betaD_sigma_prior_sd,
    nsubj = params$nsubj_sim,
    nobs_per_cond = params$nobs_per_cond_sim
  )

  sim_datasets <-
    sim_priors %>%
    mutate(
      # use draw_subj to sample nsubj_sim per sim using group-level parameter draws
      dataset = pmap(sim_priors, draw_subj),
      stimulation = list(stimulation = rep(conditions, each = nobs_per_cond))
    ) %>%

    # first unnest dataset, expanding by nsubj_sim and copying stimulation list to each subj
    unnest(dataset) %>%

    # then unnest stimulation, expanding by nobs_per_cond_sim*2
    unnest(stimulation) %>%

    # now use the likelihood to simulate observations
    mutate(
      # evaluate and delink linear model on pMem
      pMem = inv_logit(subj_beta0 + (subj_betaD * stimulation)),

      # evaluate and delink linear model on circSD/kappa
      k = sd2k_vec(
        pracma::deg2rad(
          exp(subj_alpha0 + (subj_alphaD * stimulation))
        )
      ),

      # use pMem to draw a 1 or 0 for each trial
      # (stats::rbinom replaces the deprecated purrr::rbernoulli)
      memFlip = rbinom(n(), size = 1, prob = pMem),

      # use k to draw from vonMises for each trial
      vm_draw = rvonmises_vec(1, pi, k) - pi,

      # draw from unif for each trial
      unif_draw = runif(n(), -pi, pi),

      # assign either vm_draw or unif_draw to each trial, depending on memFlip
      obs_radian = memFlip * vm_draw + (1 - memFlip) * unif_draw,

      # convert to degrees
      obs_degree = obs_radian * (180 / pi)
    ) %>%
    # drop the intermediate trial-level quantities, then re-nest:
    # trials within subject, subjects within simulated dataset
    select(-c(pMem, k, memFlip, vm_draw, unif_draw)) %>%
    nest(subj_obs = c(stimulation, obs_degree, obs_radian)) %>%
    nest(dataset = c(subj, subj_alpha0, subj_alphaD, subj_beta0, subj_betaD, nobs_per_condition, subj_obs))


  saveRDS(sim_datasets, file = sim_datasets_fpath)

}

# Flatten the nested simulation output to one row per trial and keep only the
# columns used for fitting; the response in radians is exposed as `error`.
obs_only <- sim_datasets %>%
  unnest(dataset) %>%
  unnest(subj_obs) %>%
  transmute(
    subj = as_factor(subj),
    stimulation,
    obs_degree,
    error = obs_radian
  )

peek at data

# Distribution of responses (in degrees) for one stimulation condition, one
# facet row per subject: density-scaled histogram in 10-degree bins, a rug of
# the individual trials, and a kernel density overlay.
plot_obs_by_subj <- function(condition) {
  ggplot(
    filter(obs_only, stimulation == condition),
    aes(x = obs_degree)
  ) +
    geom_histogram(aes(y = ..density..), binwidth = 10) +
    geom_rug() +
    geom_density(aes(y = ..density..)) +
    facet_wrap(vars(subj), ncol = 1)
}

# stimulation == 0
plot_obs_by_subj(0)

# stimulation == 1
plot_obs_by_subj(1)

fit brms

# Sampler settings passed to brm() below.
iter <- 6000
warmup <- 3000
cores <- 4
chains <- 4
# Total number of post-warmup draws pooled over all chains.
n_post_samples <- chains * (iter - warmup)

# Write the Stan program generated for this model next to the fit so it can
# be inspected.
stan_code <- make_stancode(
  bform_full, obs_only,
  family = vm_uniform_mix,
  prior = bprior_full,
  stanvars = stanvars
)
writeLines(stan_code, glue("{params$save_dir_str}/stan_code.txt"))

# Fit the model to the simulated observations, also sampling from the prior.
# NOTE(review): `file` presumably makes brms cache/reuse the fitted model at
# that path -- confirm against the installed brms version.
model_fit <- brm(
  bform_full, obs_only,
  family = vm_uniform_mix,
  prior = bprior_full,
  stanvars = stanvars,
  sample_prior = "yes",
  warmup = warmup,
  iter = iter,
  cores = cores,
  chains = chains,
  control = list(adapt_delta = 0.99),
  inits = 0,
  file = glue("{params$save_dir_str}/sim_model_fit")
)

print(model_fit)
##  Family: vm_uniform_mix 
##   Links: mu = identity; circSD = log; theta = logit; a = identity; b = identity 
## Formula: error ~ 0 
##          circSD ~ 0 + intercept + stimulation + (1 + stimulation || subj)
##          theta ~ 0 + intercept + stimulation + (1 + stimulation || subj)
##          a = -3.14
##          b = 3.14
##    Data: obs_only (Number of observations: 756) 
## Samples: 4 chains, each with iter = 6000; warmup = 3000; thin = 1;
##          total post-warmup samples = 12000
## 
## Group-Level Effects: 
## ~subj (Number of levels: 3) 
##                        Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
## sd(circSD_Intercept)       0.24      0.13     0.02     0.54 1.00     4563
## sd(circSD_stimulation)     0.24      0.15     0.01     0.57 1.00     4589
## sd(theta_Intercept)        0.58      0.31     0.04     1.22 1.00     4501
## sd(theta_stimulation)      0.42      0.30     0.01     1.10 1.00     4793
##                        Tail_ESS
## sd(circSD_Intercept)       3182
## sd(circSD_stimulation)     3935
## sd(theta_Intercept)        3730
## sd(theta_stimulation)      4158
## 
## Population-Level Effects: 
##                    Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
## circSD_intercept       3.87      0.18     3.51     4.23 1.00     6234
## circSD_stimulation    -0.19      0.22    -0.59     0.26 1.00     5891
## theta_intercept        0.47      0.48    -0.50     1.44 1.00     6273
## theta_stimulation     -0.20      0.48    -1.17     0.71 1.00     6713
##                    Tail_ESS
## circSD_intercept       7123
## circSD_stimulation     7016
## theta_intercept        6649
## theta_stimulation      7256
## 
## Family Specific Parameters: 
##                      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
## b_circSD_intercept       3.87      0.18     3.51     4.23 1.00     6234
## b_circSD_stimulation    -0.19      0.22    -0.59     0.26 1.00     5891
## b_theta_intercept        0.47      0.48    -0.50     1.44 1.00     6273
## b_theta_stimulation     -0.20      0.48    -1.17     0.71 1.00     6713
##                      Tail_ESS
## b_circSD_intercept       7123
## b_circSD_stimulation     7016
## b_theta_intercept        6649
## b_theta_stimulation      7256
## 
## Samples were drawn using sampling(NUTS). For each parameter, Eff.Sample 
## is a crude measure of effective sample size, and Rhat is the potential 
## scale reduction factor on split chains (at convergence, Rhat = 1).
# Quick look at the subject-level generating parameters (one row per subject).
glimpse(unnest(sim_datasets, dataset))
## Observations: 3
## Variables: 18
## $ sim_num            <int> 1, 1, 1
## $ alpha0_mu          <dbl> 3.8, 3.8, 3.8
## $ alpha0_sigma       <dbl> 0.25, 0.25, 0.25
## $ alphaD_mu          <dbl> 0, 0, 0
## $ alphaD_sigma       <dbl> 0.25, 0.25, 0.25
## $ beta0_mu           <dbl> 0, 0, 0
## $ beta0_sigma        <dbl> 0.5, 0.5, 0.5
## $ betaD_mu           <dbl> 0, 0, 0
## $ betaD_sigma        <dbl> 0.5, 0.5, 0.5
## $ nsubj              <int> 3, 3, 3
## $ nobs_per_cond      <int> 126, 126, 126
## $ subj               <int> 1, 2, 3
## $ subj_alpha0        <dbl> 3.762122, 3.610976, 4.194484
## $ subj_alphaD        <dbl> 0.07231436, -0.23306853, 0.41585695
## $ subj_beta0         <dbl> 0.46095897, 0.41920432, -0.08904445
## $ subj_betaD         <dbl> 0.76526089, -0.01460931, 0.79985402
## $ nobs_per_condition <int> 126, 126, 126
## $ subj_obs           <S3: vctrs_list_of> 0.00000000, 0.00000000, 0.00000…

fit check

divergences

# Extract the sampler diagnostics used in this section: NUTS parameters,
# R-hat values and effective-sample-size ratios.
np <- nuts_params(model_fit)
rhat <- brms::rhat(model_fit)
neff_rat <- neff_ratio(model_fit)

# Total number of divergent transitions.
summarise(
  filter(np, Parameter == "divergent__"),
  n_div = sum(Value)
)
##   n_div
## 1     0

rhat

# R-hat per parameter. The extra x scale replaces the one already set by
# mcmc_rhat(), which is what triggers the "Scale for 'x' is already present"
# message.
mcmc_rhat(rhat) +
  yaxis_text(hjust = 1) +
  scale_x_continuous(breaks = pretty_breaks(6))
## Scale for 'x' is already present. Adding another scale for 'x', which
## will replace the existing scale.

neff ratio

# Plot the effective-sample-size ratio of each parameter (from neff_ratio()),
# with the parameter names shown on the y axis.
mcmc_neff(neff_rat) + yaxis_text(hjust = 1)

trace plots

# Trace plots built from the array form of the underlying fit object.
model_fit$fit %>%
  as.array() %>%
  mcmc_trace()

other

# Default plot method for the fitted model; ask = FALSE so it does not prompt
# between plot pages.
plot(model_fit, ask = FALSE)

plot posteriors

arrange samples

# compute summaries for plot

# Posterior draws of the group-level quantities on the natural scale
# (circSD via exp(), pMem via inv_logit()). After reshaping there is one row
# per draw and param (circSD_pre/post/ES, pMem_pre/post/ES) with two columns:
#   mean - the group-level mean
#   pred - the predictive value for a new subject
# The new subject's intercept and stimulation effect are drawn once per
# posterior draw and shared by pre, post and ES. This makes the predictive ES
# a within-subject difference (matching how the modeled subjects' ES is
# computed) instead of a difference between two unrelated simulated subjects.
group_level_samples <-
  spread_draws(model_fit, `(b|sd)_.*`, regex = TRUE) %>%
  mutate(
    # group level parameters
    circSD_pre_mean  = exp(b_circSD_intercept),
    circSD_post_mean = exp(b_circSD_intercept + b_circSD_stimulation),
    circSD_ES_mean   = circSD_post_mean - circSD_pre_mean,
    pMem_pre_mean    = inv_logit(b_theta_intercept),
    pMem_post_mean   = inv_logit(b_theta_intercept + b_theta_stimulation),
    pMem_ES_mean     = pMem_post_mean - pMem_pre_mean,
    # new-subject coefficients on the linear-predictor scale (temporary columns)
    new_circSD_int   = rnorm(n(), b_circSD_intercept, sd_subj__circSD_Intercept),
    new_circSD_stim  = rnorm(n(), b_circSD_stimulation, sd_subj__circSD_stimulation),
    new_theta_int    = rnorm(n(), b_theta_intercept, sd_subj__theta_Intercept),
    new_theta_stim   = rnorm(n(), b_theta_stimulation, sd_subj__theta_stimulation),
    # predictive dist for group level parameters
    circSD_pre_pred  = exp(new_circSD_int),
    circSD_post_pred = exp(new_circSD_int + new_circSD_stim),
    circSD_ES_pred   = circSD_post_pred - circSD_pre_pred,
    pMem_pre_pred    = inv_logit(new_theta_int),
    pMem_post_pred   = inv_logit(new_theta_int + new_theta_stim),
    pMem_ES_pred     = pMem_post_pred - pMem_pre_pred
  ) %>%
  # drop the raw coefficients and the temporary new-subject draws so that only
  # the <param>_<stat> columns are reshaped
  select(-contains("b_"), -contains("sd_subj"), -starts_with("new_")) %>%
  # split "<param>_<stat>" names, then spread stat into `mean` / `pred` columns
  pivot_longer(-contains("."), names_to = c("param", "stat"), names_pattern = "(.*)_(.*)", values_to = "value") %>%
  pivot_wider(names_from = stat, values_from = value)
  
# Median and 50/80/95% quantile intervals of the group-level draws, per param.
group_level_summary <- median_qi(
  group_by(group_level_samples, param),
  .width = c(.5, .8, .95)
)



# Per-subject posterior draws of circSD before stimulation (pre), after
# stimulation (post) and their difference (ES), in long format with one row
# per draw, subject and param.
circSD_subj_samples <-
  spread_draws(
    model_fit,
    b_circSD_intercept, b_circSD_stimulation, r_subj__circSD[subj, term]
  ) %>%
  ungroup() %>%
  # one column per subject-level offset term
  pivot_wider(names_from = term, values_from = r_subj__circSD, names_prefix = "offset_") %>%
  # keep the draw indices and subject, add the delinked quantities, and drop
  # the raw coefficients
  transmute(
    .chain, .iteration, .draw, subj,
    circSD_pre = exp(b_circSD_intercept + offset_Intercept),
    circSD_post = exp(
      b_circSD_intercept + offset_Intercept +
        b_circSD_stimulation + offset_stimulation
    ),
    circSD_ES = circSD_post - circSD_pre
  ) %>%
  pivot_longer(contains("circSD"), names_to = "param", values_to = "value")

# Median and 90/95% quantile intervals per subject and param.
circSD_subj_summary <- median_qi(
  group_by(circSD_subj_samples, subj, param),
  .width = c(.90, .95)
)
## Warning: unnest() has a new interface. See ?unnest for details.
## Try `df %>% unnest(c(.lower, .upper))`, with `mutate()` if needed

## Warning: unnest() has a new interface. See ?unnest for details.
## Try `df %>% unnest(c(.lower, .upper))`, with `mutate()` if needed
# Per-subject posterior draws of pMem before stimulation (pre), after
# stimulation (post) and their difference (ES), in long format with one row
# per draw, subject and param.
pMem_subj_samples <-
  spread_draws(
    model_fit,
    b_theta_intercept, b_theta_stimulation, r_subj__theta[subj, term]
  ) %>%
  ungroup() %>%
  # one column per subject-level offset term
  pivot_wider(names_from = term, values_from = r_subj__theta, names_prefix = "offset_") %>%
  # keep the draw indices and subject, add the delinked quantities, and drop
  # the raw coefficients
  transmute(
    .chain, .iteration, .draw, subj,
    pMem_pre = inv_logit(b_theta_intercept + offset_Intercept),
    pMem_post = inv_logit(
      b_theta_intercept + offset_Intercept +
        b_theta_stimulation + offset_stimulation
    ),
    pMem_ES = pMem_post - pMem_pre
  ) %>%
  pivot_longer(contains("pMem"), names_to = "param", values_to = "value")

# Median and 90/95% quantile intervals per subject and param.
pMem_subj_summary <- median_qi(
  group_by(pMem_subj_samples, subj, param),
  .width = c(.90, .95)
)
## Warning: unnest() has a new interface. See ?unnest for details.
## Try `df %>% unnest(c(.lower, .upper))`, with `mutate()` if needed

## Warning: unnest() has a new interface. See ?unnest for details.
## Try `df %>% unnest(c(.lower, .upper))`, with `mutate()` if needed
# Table of posterior medians and 95% intervals for the group-level means only
# (the predictive column is dropped first).
median_qi(
  group_by(select(group_level_samples, -pred), param),
  .width = c(.95)
)
## Warning: unnest() has a new interface. See ?unnest for details.
## Try `df %>% unnest(c(.lower, .upper))`, with `mutate()` if needed
## # A tibble: 6 x 7
##   param           mean  .lower .upper .width .point .interval
##   <chr>          <dbl>   <dbl>  <dbl>  <dbl> <chr>  <chr>    
## 1 circSD_ES    -8.59   -25.1   13.9     0.95 median qi       
## 2 circSD_post  38.9     25.0   68.5     0.95 median qi       
## 3 circSD_pre   48.0     33.6   69.1     0.95 median qi       
## 4 pMem_ES      -0.0425  -0.268  0.160   0.95 median qi       
## 5 pMem_post     0.575    0.284  0.793   0.95 median qi       
## 6 pMem_pre      0.616    0.377  0.809   0.95 median qi

group level posteriors

# Group-level pre, post and ES plot for one parameter family ("circSD" or
# "pMem"): half-eye of the group mean posterior, interval bands of the
# predictive distribution, and one point per subject posterior median.
# `x_limits`, when given, zooms the x axis without dropping data.
plot_group_level <- function(prefix, subj_summary, x_limits = NULL) {
  group_plot <-
    ggplot(filter(group_level_samples, str_detect(param, prefix))) +
    # posterior dist + interval for group mean
    geom_halfeyeh(
      aes(y = param, x = mean),
      .width = c(.90, .95),
      position = position_nudge(y = 0.15)
    ) +
    # posterior predictive distribution for group means
    stat_intervalh(aes(y = param, x = pred), .width = c(.5, .8, .95)) +
    # posterior medians for each parameter estimate per subj
    geom_point(data = subj_summary, aes(y = param, x = value), size = 2) +
    # decorations
    scale_color_brewer() +
    scale_x_continuous(breaks = pretty_breaks(10)) +
    labs(
      subtitle = paste0(
        prefix,
        ": group level mean posterior (median, 90%, 95% interval), \nsubject posterior medians, \ncondition predictive dist of subjects"
      ),
      x = prefix,
      color = "interval"
    )

  if (!is.null(x_limits)) {
    group_plot <- group_plot + coord_cartesian(xlim = x_limits)
  }
  group_plot
}

circSD_p1 <- plot_group_level("circSD", circSD_subj_summary, x_limits = c(-50, 150))

pMem_p1 <- plot_group_level("pMem", pMem_subj_summary)

plot_grid(circSD_p1, pMem_p1, align = "hv", ncol = 1)

circSD pre, post, ES

# circSD pre / post / ES: group level posteriors and subject posteriors

# Build the plot for one circSD quantity.
#   target     - param to plot: "circSD_pre", "circSD_post" or "circSD_ES"
#   x_label    - x axis label
#   pred_label - word used in the subtitle for the predictive band
#                ("condition" or "ES")
# Layers: half-eyes for the group mean posterior and for each subject's
# posterior (rows named <target>_<subj>), the group predictive interval
# nudged below the group row, and the subject posterior medians drawn inside
# that band.
plot_circSD_param <- function(target, x_label, pred_label) {
  group_draws <- filter(group_level_samples, str_detect(param, target))

  # subject draws, relabelled so each subject gets its own row on the y axis
  subj_draws <- circSD_subj_samples %>%
    filter(str_detect(param, target)) %>%
    unite(param, param, subj)

  ggplot() +
    # group mean posterior and subject posteriors
    geom_halfeyeh(
      data = rbind(select(group_draws, -pred, value = mean), subj_draws),
      aes(y = param, x = value),
      .width = c(.90, .95)
    ) +
    # group predictive distribution
    stat_intervalh(
      data = group_draws,
      aes(y = param, x = pred),
      .width = c(.5, .8, .95),
      position = position_nudge(y = -0.15)
    ) +
    # modeled subject posterior medians, shown in the prediction band
    geom_point(
      data = filter(circSD_subj_summary, param == target),
      aes(y = param, x = value),
      size = 2,
      position = position_nudge(y = -0.15)
    ) +
    scale_color_brewer() +
    scale_x_continuous(breaks = pretty_breaks(10)) +
    labs(
      subtitle = paste0(
        target,
        ": group level mean posterior (median, 90%, 95% interval), \nsubject posterior, \n",
        pred_label,
        " predictive dist of subjects"
      ),
      x = x_label,
      color = "interval"
    )
}

circSD_p2 <- plot_circSD_param("circSD_pre", "circSD", "condition")

circSD_p3 <- plot_circSD_param("circSD_post", "circSD", "condition")

# axis label typo fixed: "detla" -> "delta"
circSD_p4 <- plot_circSD_param("circSD_ES", "delta circSD", "ES")


plot_grid(circSD_p2, circSD_p3, circSD_p4, ncol = 1, align = "hv")

pMem pre, post, ES

# pMem pre / post / ES: group level posteriors and subject posteriors

# Build the plot for one pMem quantity.
#   target     - param to plot: "pMem_pre", "pMem_post" or "pMem_ES"
#   x_label    - x axis label
#   pred_label - word used in the subtitle for the predictive band
#                ("condition" or "ES")
# Layers: half-eyes for the group mean posterior and for each subject's
# posterior (rows named <target>_<subj>), the group predictive interval
# nudged below the group row, and the subject posterior medians drawn inside
# that band.
plot_pMem_param <- function(target, x_label, pred_label) {
  group_draws <- filter(group_level_samples, str_detect(param, target))

  # subject draws, relabelled so each subject gets its own row on the y axis
  subj_draws <- pMem_subj_samples %>%
    filter(str_detect(param, target)) %>%
    unite(param, param, subj)

  ggplot() +
    # group mean posterior and subject posteriors
    geom_halfeyeh(
      data = rbind(select(group_draws, -pred, value = mean), subj_draws),
      aes(y = param, x = value),
      .width = c(.90, .95)
    ) +
    # group predictive distribution
    stat_intervalh(
      data = group_draws,
      aes(y = param, x = pred),
      .width = c(.5, .8, .95),
      position = position_nudge(y = -0.15)
    ) +
    # modeled subject posterior medians, shown in the prediction band
    geom_point(
      data = filter(pMem_subj_summary, param == target),
      aes(y = param, x = value),
      size = 2,
      position = position_nudge(y = -0.15)
    ) +
    scale_color_brewer() +
    scale_x_continuous(breaks = pretty_breaks(10)) +
    labs(
      subtitle = paste0(
        target,
        ": group level mean posterior (median, 90%, 95% interval), \nsubject posterior, \n",
        pred_label,
        " predictive dist of subjects"
      ),
      x = x_label,
      color = "interval"
    )
}

pMem_p2 <- plot_pMem_param("pMem_pre", "pMem", "condition")

pMem_p3 <- plot_pMem_param("pMem_post", "pMem", "condition")

# axis label typo fixed: "detla" -> "delta"
pMem_p4 <- plot_pMem_param("pMem_ES", "delta pMem", "ES")

plot_grid(pMem_p2, pMem_p3, pMem_p4, ncol = 1, align = "hv")

group level joint posteriors

# Pairs plot of the group-level means: one column per param, one row per
# draw. The draw-index columns (names containing ".") identify rows while
# widening and are removed before plotting.
group_means_wide <- pivot_wider(
  group_level_samples,
  id_cols = contains("."),
  names_from = param,
  values_from = mean
)

mcmc_pairs(
  select(group_means_wide, -contains(".")),
  off_diag_fun = "hex"
)
## Warning: Only one chain in 'x'. This plot is more useful with multiple
## chains.

parameter recovery

# Generating ("true") group-level values on the natural scale, computed from
# the group-level means stored with the simulated data, in the same
# param / value layout as the posterior summaries.
prior_summary <-
  sim_datasets %>%
  transmute(
    circSD_pre = exp(alpha0_mu),
    circSD_post = exp(alpha0_mu + alphaD_mu),
    circSD_ES = circSD_post - circSD_pre,
    # plogis() is the inverse logit; unlike exp(x) / (exp(x) + 1) it does not
    # overflow to NaN for large x
    pMem_pre = plogis(beta0_mu),
    pMem_post = plogis(beta0_mu + betaD_mu),
    pMem_ES = pMem_post - pMem_pre
  ) %>%
  pivot_longer(everything(), names_to = "param", values_to = "value")

# Parameter recovery: group-level mean posteriors with the generating value of
# each parameter overlaid as a red dot.
# The dot colour is set as a constant outside aes(). Inside aes(), "red" is
# treated as data and mapped through the default colour scale, which yields
# the scale's default hue and a spurious legend instead of a red dot.

circSD_p1 <- group_level_samples %>%
  filter(str_detect(param, "circSD")) %>%
  ggplot() +
  # posterior dist + interval for group mean
  geom_halfeyeh(aes(y = param, x = mean), .width = c(.90, .95), position = position_nudge(y = 0.15)) +
  # true (generating) value of each parameter
  geom_point(
    data = prior_summary %>% filter(str_detect(param, "circSD")),
    aes(y = param, x = value),
    color = "red",
    size = 2
  ) +
  # decorations
  scale_x_continuous(breaks = pretty_breaks(10)) +
  coord_cartesian(xlim = c(-50, 150)) +
  labs(subtitle = "circSD: group level mean posterior (median, 90%, 95% interval), \ntrue param value = red dot",
       x = "circSD")

# group level pMem pre, post and ES plot

pMem_p1 <- group_level_samples %>%
  filter(str_detect(param, "pMem")) %>%
  ggplot() +
  # posterior dist + interval for group mean
  geom_halfeyeh(aes(y = param, x = mean), .width = c(.90, .95), position = position_nudge(y = 0.15)) +
  # true (generating) value of each parameter
  geom_point(
    data = prior_summary %>% filter(str_detect(param, "pMem")),
    aes(y = param, x = value),
    color = "red",
    size = 2
  ) +
  # decorations
  scale_x_continuous(breaks = pretty_breaks(10)) +
  #coord_cartesian(xlim=c(-50, 150)) +
  labs(subtitle = "pMem: group level mean posterior (median, 90%, 95% interval), \ntrue param value = red dot",
       x = "pMem")

plot_grid(circSD_p1, pMem_p1, align = "hv", ncol = 1)